#include "core_functions/scalar/list_functions.hpp"
#include "core_functions/aggregate/nested_functions.hpp"
#include "duckdb/catalog/catalog.hpp"
#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/function/scalar/nested_functions.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/planner/expression/bound_cast_expression.hpp"
#include "duckdb/planner/expression_binder.hpp"
#include "duckdb/function/function_binder.hpp"
#include "duckdb/function/create_sort_key.hpp"
#include "duckdb/common/owning_string_map.hpp"

namespace duckdb {

namespace {

struct ListAggregatesLocalState : public FunctionLocalState {
	explicit ListAggregatesLocalState(Allocator &allocator) : arena_allocator(allocator) {
	}

	ArenaAllocator arena_allocator;
};

unique_ptr<FunctionLocalState> ListAggregatesInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr,
                                                            FunctionData *bind_data) {
	return make_uniq<ListAggregatesLocalState>(BufferAllocator::Get(state.GetContext()));
}
// FIXME: benchmark the use of simple_update against using update (if applicable)

unique_ptr<FunctionData> ListAggregatesBindFailure(ScalarFunction &bound_function) {
	bound_function.arguments[0] = LogicalType::SQLNULL;
	bound_function.SetReturnType(LogicalType::SQLNULL);
	return make_uniq<VariableReturnBindData>(LogicalType::SQLNULL);
}

struct ListAggregatesBindData : public FunctionData {
	ListAggregatesBindData(const LogicalType &stype_p, unique_ptr<Expression> aggr_expr_p);
	~ListAggregatesBindData() override;

	LogicalType stype;
	unique_ptr<Expression> aggr_expr;

	unique_ptr<FunctionData> Copy() const override {
		return make_uniq<ListAggregatesBindData>(stype, aggr_expr->Copy());
	}

	bool Equals(const FunctionData &other_p) const override {
		auto &other = other_p.Cast<ListAggregatesBindData>();
		return stype == other.stype && aggr_expr->Equals(*other.aggr_expr);
	}
	void Serialize(Serializer &serializer) const {
		serializer.WriteProperty(1, "stype", stype);
		serializer.WriteProperty(2, "aggr_expr", aggr_expr);
	}
	static unique_ptr<ListAggregatesBindData> Deserialize(Deserializer &deserializer) {
		auto stype = deserializer.ReadProperty<LogicalType>(1, "stype");
		auto aggr_expr = deserializer.ReadProperty<unique_ptr<Expression>>(2, "aggr_expr");
		auto result = make_uniq<ListAggregatesBindData>(std::move(stype), std::move(aggr_expr));
		return result;
	}

	static void SerializeFunction(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
	                              const ScalarFunction &function) {
		auto bind_data = dynamic_cast<const ListAggregatesBindData *>(bind_data_p.get());
		serializer.WritePropertyWithDefault(100, "bind_data", bind_data, (const ListAggregatesBindData *)nullptr);
	}

	static unique_ptr<FunctionData> DeserializeFunction(Deserializer &deserializer, ScalarFunction &bound_function) {
		auto result = deserializer.ReadPropertyWithExplicitDefault<unique_ptr<ListAggregatesBindData>>(
		    100, "bind_data", unique_ptr<ListAggregatesBindData>(nullptr));
		if (!result) {
			return ListAggregatesBindFailure(bound_function);
		}
		return std::move(result);
	}
};

ListAggregatesBindData::ListAggregatesBindData(const LogicalType &stype_p, unique_ptr<Expression> aggr_expr_p)
    : stype(stype_p), aggr_expr(std::move(aggr_expr_p)) {
}

ListAggregatesBindData::~ListAggregatesBindData() {
}

struct StateVector {
	StateVector(idx_t count_p, unique_ptr<Expression> aggr_expr_p)
	    : count(count_p), aggr_expr(std::move(aggr_expr_p)), state_vector(Vector(LogicalType::POINTER, count_p)) {
	}

	~StateVector() { // NOLINT
		// destroy objects within the aggregate states
		auto &aggr = aggr_expr->Cast<BoundAggregateExpression>();
		if (aggr.function.HasStateDestructorCallback()) {
			ArenaAllocator allocator(Allocator::DefaultAllocator());
			AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator);
			aggr.function.GetStateDestructorCallback()(state_vector, aggr_input_data, count);
		}
	}

	idx_t count;
	unique_ptr<Expression> aggr_expr;
	Vector state_vector;
};

struct FinalizeValueFunctor {
	template <class T>
	static void HistogramFinalize(T value, Vector &result, idx_t offset) {
		FlatVector::GetData<T>(result)[offset] = value;
	}
};

struct FinalizeStringValueFunctor {
	template <class T>
	static void HistogramFinalize(T value, Vector &result, idx_t offset) {
		FlatVector::GetData<string_t>(result)[offset] = StringVector::AddStringOrBlob(result, value);
	}
};

struct FinalizeGenericValueFunctor {
	template <class T>
	static void HistogramFinalize(T value, Vector &result, idx_t offset) {
		CreateSortKeyHelpers::DecodeSortKey(value, result, offset,
		                                    OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST));
	}
};

struct AggregateFunctor {
	template <class OP, class T, class MAP_TYPE = unordered_map<T, idx_t>>
	static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) {
	}
};

struct DistinctFunctor {
	template <class OP, class T, class MAP_TYPE = unordered_map<T, idx_t>>
	static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) {
		UnifiedVectorFormat sdata;
		state_vector.ToUnifiedFormat(count, sdata);
		auto states = UnifiedVectorFormat::GetData<HistogramAggState<T, MAP_TYPE> *>(sdata);

		auto old_len = ListVector::GetListSize(result);
		idx_t new_entries = 0;
		// figure out how much space we need
		for (idx_t i = 0; i < count; i++) {
			auto &state = *states[sdata.sel->get_index(i)];
			if (!state.hist) {
				continue;
			}
			new_entries += state.hist->size();
		}
		// reserve space in the list vector
		ListVector::Reserve(result, old_len + new_entries);
		auto &child_elements = ListVector::GetEntry(result);
		auto list_entries = FlatVector::GetData<list_entry_t>(result);

		idx_t current_offset = old_len;
		for (idx_t i = 0; i < count; i++) {
			const auto rid = i;
			auto &state = *states[sdata.sel->get_index(i)];
			auto &list_entry = list_entries[rid];
			list_entry.offset = current_offset;
			if (!state.hist) {
				list_entry.length = 0;
				continue;
			}

			for (auto &entry : *state.hist) {
				OP::template HistogramFinalize<T>(entry.first, child_elements, current_offset);
				current_offset++;
			}
			list_entry.length = current_offset - list_entry.offset;
		}
		D_ASSERT(current_offset == old_len + new_entries);
		ListVector::SetListSize(result, current_offset);
		result.Verify(count);
	}
};

struct UniqueFunctor {
	template <class OP, class T, class MAP_TYPE = unordered_map<T, idx_t>>
	static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) {
		UnifiedVectorFormat sdata;
		state_vector.ToUnifiedFormat(count, sdata);
		auto states = UnifiedVectorFormat::GetData<HistogramAggState<T, MAP_TYPE> *>(sdata);

		auto result_data = FlatVector::GetData<uint64_t>(result);
		for (idx_t i = 0; i < count; i++) {
			auto state = states[sdata.sel->get_index(i)];

			if (!state->hist) {
				result_data[i] = 0;
				continue;
			}
			result_data[i] = state->hist->size();
		}
		result.Verify(count);
	}
};

template <class FUNCTION_FUNCTOR, bool IS_AGGR = false>
void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	auto count = args.size();
	Vector &lists = args.data[0];

	// set the result vector
	result.SetVectorType(VectorType::FLAT_VECTOR);
	auto &result_validity = FlatVector::Validity(result);

	if (lists.GetType().id() == LogicalTypeId::SQLNULL) {
		result.SetVectorType(VectorType::CONSTANT_VECTOR);
		ConstantVector::SetNull(result, true);
		return;
	}

	// get the aggregate function
	auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
	auto &info = func_expr.bind_info->Cast<ListAggregatesBindData>();
	auto &aggr = info.aggr_expr->Cast<BoundAggregateExpression>();
	auto &allocator = ExecuteFunctionState::GetFunctionState(state)->Cast<ListAggregatesLocalState>().arena_allocator;
	allocator.Reset();
	AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator);

	D_ASSERT(aggr.function.HasStateUpdateCallback());

	auto lists_size = ListVector::GetListSize(lists);
	auto &child_vector = ListVector::GetEntry(lists);
	child_vector.Flatten(lists_size);

	UnifiedVectorFormat child_data;
	child_vector.ToUnifiedFormat(lists_size, child_data);

	UnifiedVectorFormat lists_data;
	lists.ToUnifiedFormat(count, lists_data);
	auto list_entries = UnifiedVectorFormat::GetData<list_entry_t>(lists_data);

	// state_buffer holds the state for each list of this chunk
	idx_t size = aggr.function.GetStateSizeCallback()(aggr.function);
	auto state_buffer = make_unsafe_uniq_array_uninitialized<data_t>(size * count);

	// state vector for initialize and finalize
	StateVector state_vector(count, info.aggr_expr->Copy());
	auto states = FlatVector::GetData<data_ptr_t>(state_vector.state_vector);

	// state vector of STANDARD_VECTOR_SIZE holds the pointers to the states
	Vector state_vector_update = Vector(LogicalType::POINTER);
	auto states_update = FlatVector::GetData<data_ptr_t>(state_vector_update);

	// selection vector pointing to the data
	SelectionVector sel_vector(STANDARD_VECTOR_SIZE);
	idx_t states_idx = 0;

	for (idx_t i = 0; i < count; i++) {
		// initialize the state for this list
		auto state_ptr = state_buffer.get() + size * i;
		states[i] = state_ptr;
		aggr.function.GetStateInitCallback()(aggr.function, states[i]);

		auto lists_index = lists_data.sel->get_index(i);
		const auto &list_entry = list_entries[lists_index];

		// nothing to do for this list
		if (!lists_data.validity.RowIsValid(lists_index)) {
			result_validity.SetInvalid(i);
			continue;
		}

		// skip empty list
		if (list_entry.length == 0) {
			continue;
		}

		for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) {
			// states vector is full, update
			if (states_idx == STANDARD_VECTOR_SIZE) {
				// update the aggregate state(s)
				Vector slice(child_vector, sel_vector, states_idx);
				aggr.function.GetStateUpdateCallback()(&slice, aggr_input_data, 1, state_vector_update, states_idx);

				// reset values
				states_idx = 0;
			}

			auto source_idx = child_data.sel->get_index(list_entry.offset + child_idx);
			sel_vector.set_index(states_idx, source_idx);
			states_update[states_idx] = state_ptr;
			states_idx++;
		}
	}

	// update the remaining elements of the last list(s)
	if (states_idx != 0) {
		Vector slice(child_vector, sel_vector, states_idx);
		aggr.function.GetStateUpdateCallback()(&slice, aggr_input_data, 1, state_vector_update, states_idx);
	}

	if (IS_AGGR) {
		// finalize all the aggregate states
		aggr.function.GetStateFinalizeCallback()(state_vector.state_vector, aggr_input_data, result, count, 0);

	} else {
		// finalize manually to use the map
		D_ASSERT(aggr.function.arguments.size() == 1);
		auto key_type = aggr.function.arguments[0];

		switch (key_type.InternalType()) {
#ifndef DUCKDB_SMALLER_BINARY
		case PhysicalType::BOOL:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, bool>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::UINT8:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, uint8_t>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::UINT16:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, uint16_t>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::UINT32:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, uint32_t>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::UINT64:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, uint64_t>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::INT8:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, int8_t>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::INT16:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, int16_t>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::INT32:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, int32_t>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::INT64:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, int64_t>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::FLOAT:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, float>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::DOUBLE:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeValueFunctor, double>(
			    result, state_vector.state_vector, count);
			break;
		case PhysicalType::VARCHAR:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeStringValueFunctor, string_t,
			                                               OwningStringMap<idx_t>>(result, state_vector.state_vector,
			                                                                       count);
			break;
#endif
		default:
			FUNCTION_FUNCTOR::template ListExecuteFunction<FinalizeGenericValueFunctor, string_t,
			                                               OwningStringMap<idx_t>>(result, state_vector.state_vector,
			                                                                       count);
			break;
		}
	}

	if (args.AllConstant()) {
		result.SetVectorType(VectorType::CONSTANT_VECTOR);
	}
}

void ListAggregateFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	D_ASSERT(args.ColumnCount() >= 2);
	ListAggregatesFunction<AggregateFunctor, true>(args, state, result);
}

void ListDistinctFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	D_ASSERT(args.ColumnCount() == 1);
	ListAggregatesFunction<DistinctFunctor>(args, state, result);
}

void ListUniqueFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	D_ASSERT(args.ColumnCount() == 1);
	ListAggregatesFunction<UniqueFunctor>(args, state, result);
}

template <bool IS_AGGR = false>
unique_ptr<FunctionData>
ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_function, const LogicalType &list_child_type,
                           AggregateFunction &aggr_function, vector<unique_ptr<Expression>> &arguments) {
	// create the child expression and its type
	vector<unique_ptr<Expression>> children;
	auto expr = make_uniq<BoundConstantExpression>(Value(list_child_type));
	children.push_back(std::move(expr));
	// push any extra arguments into the list aggregate bind
	if (arguments.size() > 2) {
		for (idx_t i = 2; i < arguments.size(); i++) {
			children.push_back(std::move(arguments[i]));
		}
		arguments.resize(2);
	}

	FunctionBinder function_binder(context);
	auto bound_aggr_function = function_binder.BindAggregateFunction(aggr_function, std::move(children));
	bound_function.arguments[0] = LogicalType::LIST(bound_aggr_function->function.arguments[0]);

	if (IS_AGGR) {
		bound_function.SetReturnType(bound_aggr_function->function.GetReturnType());
	}
	// check if the aggregate function consumed all the extra input arguments
	if (bound_aggr_function->children.size() > 1) {
		throw InvalidInputException(
		    "Aggregate function %s is not supported for list_aggr: extra arguments were not removed during bind",
		    bound_aggr_function->ToString());
	}

	return make_uniq<ListAggregatesBindData>(bound_function.GetReturnType(), std::move(bound_aggr_function));
}

template <bool IS_AGGR = false>
unique_ptr<FunctionData> ListAggregatesBind(ClientContext &context, ScalarFunction &bound_function,
                                            vector<unique_ptr<Expression>> &arguments) {
	arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0]));

	if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) {
		return ListAggregatesBindFailure(bound_function);
	}

	bool is_parameter = arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN;
	LogicalType child_type;
	if (is_parameter) {
		child_type = LogicalType::ANY;
	} else if (arguments[0]->return_type.id() == LogicalTypeId::LIST ||
	           arguments[0]->return_type.id() == LogicalTypeId::MAP) {
		child_type = ListType::GetChildType(arguments[0]->return_type);
	} else {
		// Unreachable
		throw InvalidInputException("First argument of list aggregate must be a list, map or array");
	}

	string function_name = "histogram";
	if (IS_AGGR) { // get the name of the aggregate function
		if (!arguments[1]->IsFoldable()) {
			throw InvalidInputException("Aggregate function name must be a constant");
		}
		// get the function name
		Value function_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]);
		function_name = function_value.ToString();
	}

	// look up the aggregate function in the catalog
	auto &func = Catalog::GetSystemCatalog(context).GetEntry<AggregateFunctionCatalogEntry>(context, DEFAULT_SCHEMA,
	                                                                                        function_name);
	D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY);

	if (is_parameter) {
		bound_function.arguments[0] = LogicalTypeId::UNKNOWN;
		bound_function.SetReturnType(LogicalType::SQLNULL);
		return nullptr;
	}

	// find a matching aggregate function
	ErrorData error;
	vector<LogicalType> types;
	types.push_back(child_type);
	// push any extra arguments into the type list
	for (idx_t i = 2; i < arguments.size(); i++) {
		types.push_back(arguments[i]->return_type);
	}

	FunctionBinder function_binder(context);
	auto best_function_idx = function_binder.BindFunction(func.name, func.functions, types, error);
	if (!best_function_idx.IsValid()) {
		throw BinderException("No matching aggregate function\n%s", error.Message());
	}

	// found a matching function, bind it as an aggregate
	auto best_function = func.functions.GetFunctionByOffset(best_function_idx.GetIndex());
	if (IS_AGGR) {
		bound_function.SetErrorMode(best_function.GetErrorMode());
		return ListAggregatesBindFunction<IS_AGGR>(context, bound_function, child_type, best_function, arguments);
	}

	// create the unordered map histogram function
	D_ASSERT(best_function.arguments.size() == 1);
	auto aggr_function = HistogramFun::GetHistogramUnorderedMap(child_type);
	return ListAggregatesBindFunction<IS_AGGR>(context, bound_function, child_type, aggr_function, arguments);
}

unique_ptr<FunctionData> ListAggregateBind(ClientContext &context, ScalarFunction &bound_function,
                                           vector<unique_ptr<Expression>> &arguments) {
	// the list column and the name of the aggregate function
	D_ASSERT(bound_function.arguments.size() >= 2);
	D_ASSERT(arguments.size() >= 2);

	return ListAggregatesBind<true>(context, bound_function, arguments);
}

} // namespace

ScalarFunction ListAggregateFun::GetFunction() {
	auto result =
	    ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, LogicalType::ANY,
	                   ListAggregateFunction, ListAggregateBind, nullptr, nullptr, ListAggregatesInitLocalState);
	result.SetFallible();
	result.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
	result.varargs = LogicalType::ANY;
	result.SetSerializeCallback(ListAggregatesBindData::SerializeFunction);
	result.SetDeserializeCallback(ListAggregatesBindData::DeserializeFunction);
	return result;
}

ScalarFunction ListDistinctFun::GetFunction() {
	return ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T"))},
	                      LogicalType::LIST(LogicalType::TEMPLATE("T")), ListDistinctFunction,
	                      ListAggregatesBind<false>, nullptr, nullptr, ListAggregatesInitLocalState);
}

ScalarFunction ListUniqueFun::GetFunction() {
	return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::UBIGINT, ListUniqueFunction,
	                      ListAggregatesBind<false>, nullptr, nullptr, ListAggregatesInitLocalState);
}

} // namespace duckdb
